// Copyright 2025 The LUCI Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Largely generated by Gemini 2.5 Pro and the following prompt:
//
// Write a Go program to inspect the AST of a source file and report lines
// containing `errors.Annotate(<var>, ...)` where <var> is not statically known
// to be nil (for example - where the line is not guarded with `if <var> !=
// nil`, or the equivalent).
//
// Use `golang.org/x/tools/go/ast/astutil` if additional ast utility functions
// are needed.
//
// ... output ...
//
// Fix the code to account for the fact that the guard for the variable may
// happen anywhere earlier in the `path` between the definition of the variable
// and the actual Annotate callsite.
//
// ... output ...
//
// Add debugging prints to this so I can figure out where this logic breaks
//
// ... output ...
//
// Remove the debugging prints and use golang.org/x/tools/go/ast/astutil to walk
// the tree instead - the bug is in the deferred function in the callback to
// ast.Inspect - this defer happens when the callback exits, but is meant to
// happen after the node is fully processed.
//
// ... output ...
//
// astutil.Cursor.Path does not exist - instead manually track the ancestors as
// you did before, but use the post-order function to pop the ancestors off
//
// ... output ...
//
// Fix the code to account for:
// * compound conditionals (e.g. `... && err != nil && ...`)
// * switch case statements checking for `err != nil`
//
// (needed manual bugfixing for hallucinated TypeCaseClause)

package main

import (
	"fmt"
	"go/ast"
	"go/parser"
	"go/token"
	"log"
	"os"
	"path/filepath"
	"strings"

	"golang.org/x/tools/go/ast/astutil"
)

// main runs the checker over the single Go source file named on the command
// line and prints a "file:line: ..." message for every <errors>.Annotate call
// whose first argument is not known to be non-nil. Identifier arguments whose
// name starts with "Err" are assumed to be sentinel errors and never reported;
// any non-identifier first argument is always reported.
func main() {
	if len(os.Args) < 2 {
		log.Fatalf("Usage: %s <file.go>", filepath.Base(os.Args[0]))
	}
	filePath := os.Args[1]

	fset := token.NewFileSet()
	fileNode, err := parser.ParseFile(fset, filePath, nil, parser.ParseComments)
	if err != nil {
		log.Fatalf("Failed to parse file %s: %v", filePath, err)
	}

	pkgName := getErrorsPackageName(fileNode)

	// stack mirrors the path from the file root to the node being visited: a
	// node is pushed by the pre-order callback and popped by the post-order
	// one, so during a node's pre-order visit the stack holds exactly its
	// ancestors.
	var stack []ast.Node

	// report prints a diagnostic for call if it is an Annotate call whose first
	// argument is not guarded. It runs before call itself is pushed.
	report := func(call *ast.CallExpr) {
		if !isErrorsAnnotateCall(call, pkgName) || len(call.Args) == 0 {
			return
		}
		pos := fset.Position(call.Lparen)
		arg, isIdent := call.Args[0].(*ast.Ident)
		switch {
		case !isIdent:
			fmt.Printf("%s:%d: call to %s.Annotate with non-variable appears unguarded\n",
				pos.Filename, pos.Line, pkgName)
		case strings.HasPrefix(arg.Name, "Err"):
			// Assumed to be a sentinel error value, hence never nil.
		case !isVarGuarded(pkgName, stack, arg.Name, call):
			fmt.Printf("%s:%d: call to %s.Annotate with variable '%s' appears unguarded\n",
				pos.Filename, pos.Line, pkgName, arg.Name)
		}
	}

	enter := func(c *astutil.Cursor) bool {
		n := c.Node()
		if n == nil {
			return true
		}
		if call, ok := n.(*ast.CallExpr); ok {
			report(call)
		}
		stack = append(stack, n)
		return true
	}
	leave := func(c *astutil.Cursor) bool {
		if c.Node() != nil && len(stack) > 0 {
			stack = stack[:len(stack)-1]
		}
		return true
	}

	astutil.Apply(fileNode, enter, leave)
}

// getErrorsPackageName returns the local name under which fileNode imports
// go.chromium.org/luci/common/errors: the explicit import alias if one is
// given, otherwise "errors". It also returns "errors" when the package is not
// imported at all.
func getErrorsPackageName(fileNode *ast.File) string {
	const errorsImportPath = "go.chromium.org/luci/common/errors"
	for _, imp := range fileNode.Imports {
		// imp.Path.Value is the literal exactly as written, delimited either by
		// double quotes or by backquotes (both are legal for import paths), so
		// strip one delimiter from each end before comparing.
		lit := imp.Path.Value
		if len(lit) < 2 || lit[1:len(lit)-1] != errorsImportPath {
			continue
		}
		if imp.Name != nil {
			return imp.Name.Name
		}
		return "errors"
	}
	return "errors"
}

// isErrorsAnnotateCall reports whether callExpr has the form
// <errorsPkgName>.Annotate(...).
func isErrorsAnnotateCall(callExpr *ast.CallExpr, errorsPkgName string) bool {
	return isPkgFuncCall(callExpr, errorsPkgName, "Annotate")
}

// isErrorsIsCall reports whether callExpr has the form <errorsPkgName>.Is(...).
func isErrorsIsCall(callExpr *ast.CallExpr, errorsPkgName string) bool {
	return isPkgFuncCall(callExpr, errorsPkgName, "Is")
}

// isPkgFuncCall reports whether callExpr calls pkgName.funcName, where the
// qualifier is a bare identifier. Matching is purely by name, so a local
// variable that shares pkgName's name is indistinguishable from the package.
func isPkgFuncCall(callExpr *ast.CallExpr, pkgName, funcName string) bool {
	sel, ok := callExpr.Fun.(*ast.SelectorExpr)
	if !ok || sel.Sel.Name != funcName {
		return false
	}
	qualifier, ok := sel.X.(*ast.Ident)
	return ok && qualifier.Name == pkgName
}

func isVarGuarded(errorsPkgName string, currentAncestors []ast.Node, varName string, callExpr *ast.CallExpr) bool {
	callPos := callExpr.Pos()

	// Phase 1: Check for direct nesting guard (IfStmt, SwitchStmt, TypeSwitchStmt)
	for i := len(currentAncestors) - 1; i >= 0; i-- {
		node := currentAncestors[i]
		if ifStmt, ok := node.(*ast.IfStmt); ok {
			if ifStmt.Body != nil && nodeContains(ifStmt.Body, callPos) {
				if checkCondition(errorsPkgName, ifStmt.Cond, varName, true) { // Expect varName != nil
					return true
				}
			}
			if ifStmt.Else != nil && nodeContains(ifStmt.Else, callPos) {
				if checkCondition(errorsPkgName, ifStmt.Cond, varName, false) { // Expect varName == nil (so else means varName != nil context)
					return true
				}
			}
		} else if switchStmt, ok := node.(*ast.SwitchStmt); ok {
			// note that switchStmt.Tag==nil (tagless switch statement) will return isTagIdent == false
			tagIdentName, isTagIdent := getIdentName(switchStmt.Tag)
			if isTagIdent && tagIdentName != varName {
				// skip this switch statement in the ancestry chain, it doesn't directly
				// concern our error var.
				continue
			}

			for _, stmt := range switchStmt.Body.List { // Iterate over *ast.CaseClause
				caseClause, ok := stmt.(*ast.CaseClause)
				if !ok {
					continue // not possible
				}

				if !nodeContains(caseClause, callPos) {
					// Call is NOT in this case's statements, but we can check if it
					// has a condition which would handle the error, e.g.
					//
					// switch {
					//   case err == nil:  // < checking this
					//   default:
					//     // our call
					// }
					if switchStmt.Tag == nil {
						// Tagless switch (switch { case cond: ... })
						for _, caseCond := range caseClause.List { // case expr1, expr2:
							if checkCondition(errorsPkgName, caseCond, varName, false) { // false means we expect varName == nil
								return true // Call is guarded
							}
						}
					} else {
						// Tagged switch (switch err { case value: ... })
						for _, caseExpr := range caseClause.List {
							if name, isIdent := getIdentName(caseExpr); isIdent && name == "nil" {
								return true // if err was nil, it would match this case, not our callsite
							}
						}

						// don't need to check default - default is the last case, and
						// doesn't have our call in it, so it doesn't matter.
					}
				} else {
					// Call IS in this caseClause.
					if switchStmt.Tag == nil {
						// Tagless switch (switch { case cond: ... })
						for _, caseCond := range caseClause.List { // case expr1, expr2:
							if checkCondition(errorsPkgName, caseCond, varName, true) { // true means we expect varName != nil
								return true // Guarded
							}
						}
					} else {
						// Tagged switch (switch err { case value: ... })
						for _, caseExpr := range caseClause.List {
							if name, isIdent := getIdentName(caseExpr); isIdent {
								if name == "nil" {
									break // keep looking up the stack
								}
								if strings.HasPrefix(name, "err") || strings.HasPrefix(name, "Err") {
									return true // assume this is a sentinel error and non-nil
								}
							}
						}

						// don't need to check default - default doesn't have any condition
						// associated with it, and we know this contains our call.
						//
						// If one of the previous statements guarded our call, we would
						// already have returned.
					}

					// If the call is in this case, we've analyzed it. No need to check other cases for this call.
					break
				}
			}
		} else if typeSwitchStmt, ok := node.(*ast.TypeSwitchStmt); ok {
			var switchedVarName string
			// Check if varName is the variable assigned by the type switch, e.g., v in v := x.(type)
			if assignStmt, isAssign := typeSwitchStmt.Assign.(*ast.AssignStmt); isAssign {
				if len(assignStmt.Lhs) == 1 {
					if ident, isIdent := assignStmt.Lhs[0].(*ast.Ident); isIdent {
						switchedVarName = ident.Name
					}
				}
			}
			// If errors.Annotate(varName,...) uses the variable from the type switch.
			if switchedVarName == varName {
				for _, stmt := range typeSwitchStmt.Body.List {
					caseClause, ok := stmt.(*ast.CaseClause)
					if !ok {
						continue
					}

					if !nodeContains(caseClause, callPos) {
						continue // Call is not in this type case clause
					}

					// Call is in this CaseClause.
					isNilTypeCase := false
					if caseClause.List != nil { // Not default, check type expressions
						for _, typeExpr := range caseClause.List {
							// Check if one of the types in `case typeA, typeB:` is the identifier `nil`
							if ident, isIdent := typeExpr.(*ast.Ident); isIdent && ident.Name == "nil" {
								isNilTypeCase = true
								break
							}
						}
					}

					if !isNilTypeCase {
						// If it's 'default' or a case with specific non-nil types (not 'case nil:').
						// Then 'varName' (the v from v:= subject.(type)) is of a specific non-nil type.
						return true // Guarded
					}
					// If it is 'case nil:', it's not guarded for 'varName != nil'.
					break // Analyzed the type case clause containing the call.
				}
			}
		}

		if _, isFunc := node.(*ast.FuncDecl); isFunc {
			break // Stop phase 1 at function boundary
		}
	}

	// Phase 2: Check for preceding "early exit" guards in the same block as callExpr.
	if len(currentAncestors) > 0 {
		immediateParentNode := currentAncestors[len(currentAncestors)-1]
		var stmtListToCheck []ast.Stmt

		switch p := immediateParentNode.(type) {
		case *ast.BlockStmt:
			stmtListToCheck = p.List
		case *ast.CaseClause: // Body of a case in a switch
			stmtListToCheck = p.Body
		case *ast.CommClause: // Body of a case in a select
			stmtListToCheck = p.Body
		}

		if stmtListToCheck != nil {
			if checkStmtsForEarlyExitGuard(errorsPkgName, stmtListToCheck, varName, callExpr) {
				return true
			}
		}
	}
	return false
}

func checkStmtsForEarlyExitGuard(errorsPkgName string, stmts []ast.Stmt, varName string, callNode ast.Node) bool {
	callPos := callNode.Pos()
	for _, stmtInList := range stmts {
		if stmtInList.End() >= callPos {
			break
		}
		if ifStmt, ok := stmtInList.(*ast.IfStmt); ok {
			// Pattern 1: if varName == nil { /* body always exits */ }
			if checkCondition(errorsPkgName, ifStmt.Cond, varName, false) {
				if ifStmt.Body != nil && stmtAlwaysExits(ifStmt.Body) {
					return true
				}
			}
			// Pattern 2: if varName != nil { /* no exit */ } else { /* else body ALWAYS exits */ }
			if checkCondition(errorsPkgName, ifStmt.Cond, varName, true) {
				if ifStmt.Else != nil && stmtAlwaysExits(ifStmt.Else) {
					return true
				}
			}
		}
	}
	return false
}

// stmtAlwaysExits reports whether executing stmt never falls through to the
// following statement. That is the case when it ends in a return, or in a
// call to panic, os.Exit, runtime.Goexit, or one of log.Fatal*/log.Panic*.
//
// For a block only the final statement is examined; an if statement exits
// only when it has an else and both branches exit. Calls are matched purely
// by name, so a shadowed `panic` or an unrelated identifier named `log`/`os`
// would be misclassified.
func stmtAlwaysExits(stmt ast.Node) bool {
	switch s := stmt.(type) {
	case *ast.BlockStmt:
		if s == nil || len(s.List) == 0 {
			return false
		}
		return stmtAlwaysExits(s.List[len(s.List)-1])
	case *ast.LabeledStmt:
		return stmtAlwaysExits(s.Stmt)
	case *ast.ReturnStmt:
		return true
	case *ast.ExprStmt:
		call, ok := s.X.(*ast.CallExpr)
		if !ok {
			return false
		}
		switch fn := call.Fun.(type) {
		case *ast.Ident:
			return fn.Name == "panic"
		case *ast.SelectorExpr:
			pkg, ok := fn.X.(*ast.Ident)
			if !ok {
				return false
			}
			switch pkg.Name + "." + fn.Sel.Name {
			case "os.Exit", "runtime.Goexit",
				"log.Fatal", "log.Fatalf", "log.Fatalln",
				"log.Panic", "log.Panicf", "log.Panicln":
				return true
			}
		}
		return false
	case *ast.IfStmt:
		return s.Else != nil && stmtAlwaysExits(s.Body) && stmtAlwaysExits(s.Else)
	default:
		return false
	}
}

// nodeContains reports whether targetPos lies in the half-open source range
// [container.Pos(), container.End()). A nil container, or any invalid
// position among the three, yields false.
func nodeContains(container ast.Node, targetPos token.Pos) bool {
	if container == nil || !targetPos.IsValid() {
		return false
	}
	start := container.Pos()
	if !start.IsValid() {
		return false
	}
	end := container.End()
	if !end.IsValid() {
		return false
	}
	return start <= targetPos && targetPos < end
}

// checkCondition reports whether condExpr pins down the nil-ness of varName.
//
// With expectNotEqualsNil == true it reports whether condExpr being TRUE
// implies varName != nil, e.g.
//
//	err != nil    err != nil && ok    errors.Is(err, ErrX)    err == ErrSentinel
//
// With expectNotEqualsNil == false it reports whether condExpr being FALSE
// implies varName != nil (equivalently: varName == nil forces it true), e.g.
//
//	err == nil    err == nil || done    err != ErrSentinel
//
// Comparison operands that are identifiers starting with "err"/"Err" (other
// than varName itself), or selectors such as pkg.ErrFoo, are heuristically
// assumed to be non-nil sentinel errors.
func checkCondition(errorsPkgName string, condExpr ast.Expr, varName string, expectNotEqualsNil bool) bool {
	switch expr := condExpr.(type) {
	case *ast.ParenExpr: // (cond)
		return checkCondition(errorsPkgName, expr.X, varName, expectNotEqualsNil)

	case *ast.UnaryExpr: // !cond swaps which outcome carries information
		if expr.Op == token.NOT {
			return checkCondition(errorsPkgName, expr.X, varName, !expectNotEqualsNil)
		}

	case *ast.CallExpr:
		// errors.Is(varName, target) can only be true when varName is non-nil,
		// unless target is literally nil. Its being false says nothing.
		if expectNotEqualsNil && isErrorsIsCall(expr, errorsPkgName) && len(expr.Args) == 2 {
			subject, _ := getIdentName(expr.Args[0])
			target, _ := getIdentName(expr.Args[1])
			return subject == varName && target != "nil"
		}

	case *ast.BinaryExpr:
		switch expr.Op {
		case token.LAND, token.LOR:
			x := checkCondition(errorsPkgName, expr.X, varName, expectNotEqualsNil)
			y := checkCondition(errorsPkgName, expr.Y, varName, expectNotEqualsNil)
			// `A && B` being true means both are true, and `A || B` being false
			// means both are false; then one informative operand suffices.
			// Otherwise (De Morgan) both operands must be informative.
			if (expr.Op == token.LAND) == expectNotEqualsNil {
				return x || y
			}
			return x && y

		case token.EQL, token.NEQ:
			xName, _ := getIdentName(expr.X)
			yName, _ := getIdentName(expr.Y)
			var other ast.Expr
			if xName == varName {
				other = expr.Y
			} else if yName == varName {
				other = expr.X
			} else {
				return false // the comparison does not involve varName directly
			}

			// Judge the OTHER operand, never varName itself.
			otherName, otherIsIdent := getIdentName(other)
			looksSentinel := otherIsIdent && otherName != varName &&
				(strings.HasPrefix(otherName, "err") || strings.HasPrefix(otherName, "Err"))
			if sel, ok := other.(*ast.SelectorExpr); ok {
				looksSentinel = strings.HasPrefix(sel.Sel.Name, "Err")
			}

			// nonNilWhenTrue is the operator for which the comparison being true
			// implies varName != nil.
			var nonNilWhenTrue token.Token
			switch {
			case otherIsIdent && otherName == "nil":
				nonNilWhenTrue = token.NEQ // varName != nil
			case looksSentinel:
				nonNilWhenTrue = token.EQL // varName == ErrSentinel
			default:
				return false
			}
			if expectNotEqualsNil {
				return expr.Op == nonNilWhenTrue
			}
			// The comparison being false must imply non-nil: opposite operator.
			return expr.Op != nonNilWhenTrue
		}
	}
	return false
}

// getIdentName returns expr's name and true when expr is a bare identifier
// (*ast.Ident); for any other expression, including nil, it returns "" and
// false.
func getIdentName(expr ast.Expr) (string, bool) {
	ident, isIdent := expr.(*ast.Ident)
	if !isIdent {
		return "", false
	}
	return ident.Name, true
}
